use metal::ComputeCommandEncoder;
use mpsgraph::CommandBuffer as MPSCommandBuffer;

use super::{
    ForwardPassState,
    encodable_with_state::{EncodableWithState, EncodingParameters},
};

/// Routes encodable blocks onto a command buffer, letting consecutive blocks that
/// support it share a single compute command encoder.
///
/// Blocks whose [`EncodableWithState::supports_shared_encoder`] returns `true` are
/// encoded into one lazily created [`ComputeCommandEncoder`], taken from the root
/// command buffer of the wrapped [`MPSCommandBuffer`]. Before a block that does not
/// support sharing is encoded, the shared encoder (if any) is ended. That block then
/// encodes against the command buffer itself.
///
/// Call [`EncoderResolver::finish`] once all blocks have been encoded. As a safety
/// net, dropping the resolver also ends an encoder that is still open.
pub struct EncoderResolver<'a> {
    /// Command buffer that every block is encoded into.
    command_buffer: &'a MPSCommandBuffer,
    /// Currently open shared compute encoder, or `None` when no encoder is active.
    encoder: Option<ComputeCommandEncoder>,
}

impl<'a> EncoderResolver<'a> {
    /// Creates a resolver over `command_buffer` with no encoder open yet.
    pub fn new(command_buffer: &'a MPSCommandBuffer) -> Self {
        Self {
            command_buffer,
            encoder: None,
        }
    }

    /// Encodes `block` into the command buffer.
    ///
    /// If `block` supports a shared encoder, it is encoded into the current shared
    /// compute encoder, which is created on first use. Otherwise any open shared
    /// encoder is ended first, and `block` is handed the command buffer directly.
    pub fn encode(
        &mut self,
        block: &dyn EncodableWithState,
        state: &mut ForwardPassState,
        parameters: &EncodingParameters,
    ) {
        if block.supports_shared_encoder() {
            // Copy the `&'a` reference out so the closure below does not borrow
            // `self` while `self.encoder` is mutably borrowed.
            let command_buffer = self.command_buffer;
            let encoder: &ComputeCommandEncoder = self.encoder.get_or_insert_with(|| {
                // `to_owned` turns the borrowed encoder reference into an owned
                // handle that can be stored in `self.encoder`.
                command_buffer
                    .root_command_buffer()
                    .new_compute_command_encoder()
                    .to_owned()
            });
            block.encode_with_shared_encoder(state, encoder, parameters);
        } else {
            // Metal allows only one active encoder per command buffer, so the shared
            // encoder must be ended before the block encodes on its own.
            self.end_current_encoder();
            block.encode(state, self.command_buffer, parameters);
        }
    }

    /// Ends the shared encoder if one is open; otherwise does nothing.
    ///
    /// The next shared-encoder block will create a fresh encoder.
    pub fn end_current_encoder(&mut self) {
        if let Some(encoder) = self.encoder.take() {
            encoder.end_encoding();
        }
    }

    /// Consumes the resolver, ending any still-open shared encoder.
    pub fn finish(mut self) {
        self.end_current_encoder();
    }
}

impl Drop for EncoderResolver<'_> {
    /// Ends a still-open shared encoder so it is never released without
    /// `end_encoding` having been called. This covers early returns or unwinding
    /// that skip `finish`. It is a no-op after `finish` because the encoder slot is
    /// already empty.
    fn drop(&mut self) {
        self.end_current_encoder();
    }
}
